#!/usr/bin/env python3

"""
memory_checker

This script is part of the feature which restarts a container if its memory
usage is larger than the threshold value.

This script is used to check the memory usage of the specified container and
is intended to be run by Monit. It will write an alerting message into
syslog if the memory usage of the container is larger than the threshold value for X
times within Y cycles/minutes. Note that if a print(...) statement in this script
is executed, the string in it will be appended to the Monit syslog messages.

The following is an example in Monit configuration file to show how Monit will run
this script:

check program container_memory_<container_name> with path "/usr/bin/memory_checker <container_name> <threshold_value>"
    if status == 3 for X times within Y cycles exec "/usr/bin/restart_service <container_name>"
"""

import argparse
import os
import subprocess
import sys
import syslog
import re
import time

import docker

from swsscommon import swsscommon

# Source and tag handed to the swsscommon event publisher when the memory
# threshold is exceeded.
EVENTS_PUBLISHER_SOURCE = "sonic-events-host"
EVENTS_PUBLISHER_TAG = "mem-threshold-exceeded"

# Base directory of the per-container cgroup memory files; the container ID
# and a file name are appended to it.
CGROUP_DOCKER_MEMORY_DIR = "/sys/fs/cgroup/memory/docker/"

# Syslog message templates for the failure paths of this script. Each one is
# filled in with str.format() before being logged.
ERROR_CONTAINER_ID_NOT_FOUND = "[memory_checker] Failed to get container ID of '{}'! Exiting ..."
ERROR_CGROUP_MEMORY_USAGE_NOT_FOUND = "[memory_checker] cgroup memory usage file '{}' of container '{}' does not exist on device! Exiting ..."
ERROR_CONTAINER_MEMORY_USAGE_NOT_FOUND = "[memory_checker] Failed to get the memory usage of container '{}'! Exiting ..."
ERROR_CONTAINER_CACHE_USAGE_NOT_FOUND = "[memory_checker] Failed to get the cache usage of container '{}'! Exiting ..."
ERROR_CGROUP_MEMORY_STATS_NOT_FOUND = "[memory_checker] cgroup memory statistics file '{}' of container '{}' does not exist on device! Exiting ..."
ERROR_CGROUP_MEMORY_STATS_LINE_FORMAT = "[memory_checker] cgroup memory statistics file '{}' of container '{}' has invalid line format! Exiting ..."

# Process exit codes. The Monit rule shown in the module docstring reacts to
# status 3 (EXCEED_THRESHOLD), so these values must not be renumbered.
CONTAINER_NOT_RUNNING = 0
INTERNAL_ERROR = 1
INVALID_VALUE = 2
EXCEED_THRESHOLD = 3

def validate_container_id(container_id):
    """Ensures a container ID consists solely of ASCII letters and digits.

    Callers in this file embed the ID into a cgroup file path, so anything
    other than alphanumeric characters is rejected.

    Args:
        container_id: A string indicates the ID of a container.

    Returns:
        None. If the ID is invalid, an error is written into syslog and the
        script exits with INTERNAL_ERROR.
    """
    # fullmatch() is used rather than match() with a trailing '$' because '$'
    # also matches just before a final newline, which would let an ID with a
    # trailing newline character slip through.
    if not re.fullmatch(r'[a-zA-Z0-9]+', container_id):
        syslog.syslog(syslog.LOG_ERR, "Invalid container_id: {}".format(container_id))
        sys.exit(INTERNAL_ERROR)

def get_command_result(command):
    """Runs an external command and hands back what it printed on stdout.

    Args:
        command: The command to run, given as a list of argument strings
            with the program name first.

    Returns:
        A string which contains the standard output of the command with
        leading and trailing whitespace removed. If the command cannot be
        started or finishes with a non-zero return code, an error is written
        into syslog and the script exits with INTERNAL_ERROR.
    """
    try:
        child = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                 universal_newlines=True)
        # stderr is drained so the child cannot block on a full pipe, but its
        # content is not used.
        output, _ = child.communicate()
        exit_status = child.returncode
        if exit_status != 0:
            message = "[memory_checker] Failed to execute the command '{}'. Return code: '{}'"
            syslog.syslog(syslog.LOG_ERR, message.format(' '.join(command), exit_status))
            sys.exit(INTERNAL_ERROR)
    except (OSError, ValueError) as exc:
        message = "[memory_checker] Failed to execute the command '{}'. Error: '{}'"
        syslog.syslog(syslog.LOG_ERR, message.format(' '.join(command), exc))
        sys.exit(INTERNAL_ERROR)

    return output.strip()

def get_container_id(container_name):
    """Gets full container ID of the specified container.

    The 'name' filter of 'docker ps' matches substrings, so asking for
    'database' also lists 'database0'. The output is therefore restricted to
    the ID and name columns and the name is compared exactly.

    Args:
        container_name: A string indicates the name of specified container.

    Returns:
        container_id: A string indicates the full ID of specified container.
        If no running container carries exactly this name, an error is
        written into syslog and the script exits with INTERNAL_ERROR.
    """
    # '--format' removes the table header and the image/command columns, which
    # could otherwise contain the container name by coincidence.
    get_container_info_cmd = ["docker", "ps", "--no-trunc",
                              "--filter", "name={}".format(container_name),
                              "--format", "{{.ID}} {{.Names}}"]

    command_stdout = get_command_result(get_container_info_cmd)

    for line in command_stdout.splitlines():
        fields = line.split()
        if len(fields) < 2:
            continue
        candidate_id, names = fields[0], fields[1]
        # A container may be listed under several comma-separated names.
        if container_name in names.split(","):
            return candidate_id

    syslog.syslog(syslog.LOG_ERR, ERROR_CONTAINER_ID_NOT_FOUND.format(container_name))
    sys.exit(INTERNAL_ERROR)

def get_memory_usage(container_id):
    """Fetches the memory usage of a container from its cgroup file
    '/sys/fs/cgroup/memory/docker/<container_id>/memory.usage_in_bytes'.

    A missing file is retried: up to three read attempts are made with a
    0.1 second pause between them. Any other I/O error, or the file still
    being absent after the last attempt, is written into syslog and the
    script exits with INTERNAL_ERROR.

    Args:
        container_id: A string indicates the full ID of a container.

    Returns:
        A string indicates memory usage (Bytes) of a container.
    """
    validate_container_id(container_id)

    usage_path = "{}{}/memory.usage_in_bytes".format(CGROUP_DOCKER_MEMORY_DIR, container_id)

    attempts_left = 3
    while attempts_left:
        attempts_left -= 1
        try:
            with open(usage_path, 'r') as usage_file:
                return usage_file.read().strip()
        except FileNotFoundError:
            # Pause only when another attempt will follow.
            if attempts_left:
                time.sleep(0.1)
        except IOError:
            syslog.syslog(syslog.LOG_ERR, ERROR_CONTAINER_MEMORY_USAGE_NOT_FOUND.format(container_id))
            sys.exit(INTERNAL_ERROR)

    syslog.syslog(syslog.LOG_ERR, ERROR_CGROUP_MEMORY_USAGE_NOT_FOUND.format(usage_path, container_id))
    sys.exit(INTERNAL_ERROR)

def get_inactive_cache_usage(container_id):
    """Fetches the value of the 'total_inactive_file' field from the cgroup file
    '/sys/fs/cgroup/memory/docker/<container_id>/memory.stat'.

    If the file is absent, cannot be read, or the matching line has fewer
    than two fields, an error is written into syslog and the script exits
    with INTERNAL_ERROR.

    Args:
        container_id: A string indicates the full ID of a container.

    Returns:
        A string indicates the cache usage (Bytes) of a container. It is an
        empty string when the file has no 'total_inactive_file' line.
    """
    validate_container_id(container_id)

    stat_path = "{}{}/memory.stat".format(CGROUP_DOCKER_MEMORY_DIR, container_id)
    if not os.path.exists(stat_path):
        syslog.syslog(syslog.LOG_ERR, ERROR_CGROUP_MEMORY_STATS_NOT_FOUND.format(stat_path, container_id))
        sys.exit(INTERNAL_ERROR)

    inactive_cache = ""
    try:
        with open(stat_path, 'r') as stat_file:
            for row in stat_file:
                if "total_inactive_file" not in row:
                    continue

                # Only the first matching line is considered.
                tokens = row.split()
                if len(tokens) < 2:
                    syslog.syslog(syslog.LOG_ERR, ERROR_CGROUP_MEMORY_STATS_LINE_FORMAT.format(stat_path, container_id))
                    sys.exit(INTERNAL_ERROR)
                inactive_cache = tokens[1].strip()
                break
    except IOError:
        syslog.syslog(syslog.LOG_ERR, ERROR_CONTAINER_CACHE_USAGE_NOT_FOUND.format(container_id))
        sys.exit(INTERNAL_ERROR)

    return inactive_cache

def publish_events(container_name, mem_usage_bytes, threshold_value):
    """Publishes a 'mem-threshold-exceeded' event for a container.

    Args:
        container_name: A string indicates the name of the container.
        mem_usage_bytes: The memory usage stored in the 'mem_usage' field of
            the event.
        threshold_value: The threshold stored in the 'threshold' field of the
            event.

    Returns:
        None.
    """
    events_handle = swsscommon.events_init_publisher(EVENTS_PUBLISHER_SOURCE)
    try:
        params = swsscommon.FieldValueMap()
        params["ctr_name"] = container_name
        params["mem_usage"] = mem_usage_bytes
        params["threshold"] = threshold_value
        swsscommon.event_publish(events_handle, EVENTS_PUBLISHER_TAG, params)
    finally:
        # Release the publisher handle even if building or publishing the
        # event raises.
        swsscommon.events_deinit_publisher(events_handle)


def check_memory_usage(container_name, threshold_value):
    """Compares the memory usage of a container, taken from its cgroup files,
    against a threshold.

    The compared value is the cgroup memory usage minus the inactive file
    cache. When it is larger than the threshold, a message is printed, an
    event is published and the script exits with EXCEED_THRESHOLD.

    Args:
        container_name: A string represents the name of a container.
        threshold_value: An integer indicates the threshold value (Bytes) of memory usage.

    Returns:
        None. The script exits with INVALID_VALUE if the threshold is not a
        positive integer or if the usage values read from the cgroup files
        are not integers.
    """
    if not (isinstance(threshold_value, int) and threshold_value > 0):
        syslog.syslog(syslog.LOG_ERR, "[memory_checker] Invalid threshold value! Threshold value should be a positive integer.")
        sys.exit(INVALID_VALUE)

    cid = get_container_id(container_name)
    syslog.syslog(syslog.LOG_INFO,
                  "[memory_checker] Container ID of '{}' is: '{}'.".format(container_name, cid))

    raw_usage = get_memory_usage(cid)
    syslog.syslog(syslog.LOG_INFO,
                  "[memory_checker] The memory usage of container '{}' is '{}' Bytes!".format(container_name, raw_usage))

    raw_cache = get_inactive_cache_usage(cid)
    syslog.syslog(syslog.LOG_INFO,
                  "[memory_checker] The cache usage of container '{}' is '{}' Bytes!".format(container_name, raw_cache))

    try:
        usage_bytes = int(raw_usage)
        cache_bytes = int(raw_cache)
    except ValueError:
        syslog.syslog(syslog.LOG_ERR, "[memory_checker] Failed to convert the memory or cache usage in string to integer! Exiting ...")
        sys.exit(INVALID_VALUE)

    # The inactive file cache is subtracted from the reported usage.
    net_usage = usage_bytes - cache_bytes
    syslog.syslog(syslog.LOG_INFO,
                  "[memory_checker] Total memory usage of container '{}' is '{}' Bytes!".format(container_name, net_usage))

    if net_usage <= threshold_value:
        return

    # As the module docstring notes, Monit appends printed text to its syslog message.
    print("[{}]: Memory usage ({} Bytes) is larger than the threshold ({} Bytes)!"
          .format(container_name, net_usage, threshold_value))
    publish_events(container_name, "{:.2f}".format(net_usage), str(threshold_value))
    sys.exit(EXCEED_THRESHOLD)

def is_service_active(service_name):
    """Asks systemd whether a service is currently active.

    Args:
        service_name: A string contains the service name

    Returns:
        True if 'systemctl is-active' returns 0 for the service, False otherwise
    """
    query = ["systemctl", "is-active", "--quiet", service_name]
    completed = subprocess.run(query)
    return completed.returncode == 0


def get_running_container_names():
    """Asks the docker daemon for the names of all running containers.

    Args:
        None.

    Returns:
        A list indicates names of running containers. The list is empty when
        the query fails and the docker service is not active. When the query
        fails although the docker service is active, an error is written into
        syslog and the script exits with INTERNAL_ERROR.
    """
    try:
        client = docker.DockerClient(base_url='unix://var/run/docker.sock')
        running = client.containers.list(filters={"status": "running"})
        return [entry.name for entry in running]
    except (docker.errors.APIError, docker.errors.DockerException) as exc:
        if is_service_active("docker"):
            # The daemon is up, so the failure is unexpected.
            syslog.syslog(syslog.LOG_ERR,
                          "Failed to retrieve the running container list from docker daemon! Error message is: '{}'"
                          .format(exc))
            sys.exit(INTERNAL_ERROR)

        syslog.syslog(syslog.LOG_INFO,
                      "[memory_checker] Docker service is not running. Error message is: '{}'".format(exc))
        return []


def main():
    """Parses the command line and checks the memory usage of the named container.

    The script exits with CONTAINER_NOT_RUNNING when the docker service is not
    active, and returns without checking when the container is not among the
    running containers.
    """
    description = "Check memory usage of a container \
            and an alerting message will be written into syslog if memory usage \
            is larger than the threshold value"
    parser = argparse.ArgumentParser(
        description=description,
        usage="/usr/bin/memory_checker <container_name> <threshold_value_in_bytes>")
    parser.add_argument("container_name", help="container name")
    # TODO: Currently the threshold value is hard coded as a command line argument and will
    # remove this in the new version since we want to read this value from 'CONFIG_DB'.
    parser.add_argument("threshold_value", type=int, help="threshold value in bytes")
    args = parser.parse_args()
    container_name = args.container_name

    if not is_service_active("docker"):
        syslog.syslog(syslog.LOG_INFO,
                      "[memory_checker] Exits without checking memory usage of container '{}' since docker daemon is not running!"
                      .format(container_name))
        sys.exit(CONTAINER_NOT_RUNNING)

    if container_name not in get_running_container_names():
        syslog.syslog(syslog.LOG_INFO,
                      "[memory_checker] Exits without checking memory usage since container '{}' is not running!"
                      .format(container_name))
        return

    check_memory_usage(container_name, args.threshold_value)

# Run the check only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
